# -*- coding: utf-8 -*-
"""
Crossover analysis core (fit-based central value + robust fallback errors).

Central value:
  - Use the project's fit_string_tension(W) when there are ≥3 clean loop sizes.
  - Otherwise, fall back to an OLS slope on y=-log(mean|W|) vs area (s^2).

Uncertainty (never raises on thin data):
  - Bootstrap over loop placements (if enough samples per size).
  - Size-jackknife (leave-one-size-out) when ≥3 sizes.
  - Whichever of these two are available are combined in quadrature.
  - If both are unavailable, fall back to the OLS slope SE (0.0 if that is undefined).

Loop-size window (default):
  - lo..min(8, L//2) with lo=1 for L<=8 and lo=2 otherwise; if that leaves
    fewer than 3 sizes, fall back to 1..min(3, max(2, L//2)).

Provenance:
  - Writes kernel/field RMS, g(D) range, per-size counts & mean(|W|),
    and which error sources contributed (bootstrap/jackknife/ols).
"""

import os
import pathlib
from typing import Dict, Sequence, Tuple, Optional

import numpy as np
import sys

# Repo roots on path
# ROOT is the directory two levels above this file's own directory
# (presumably the repository root -- confirm against the checkout layout).
# ROOT and ROOT/"orig" are made importable so the `orig.simulation.*`
# imports that follow can be resolved.
ROOT = pathlib.Path(__file__).resolve().parents[2]
for p in (ROOT, ROOT / "orig"):
    sp = str(p)
    # Only add each location once; prepend so it is searched first.
    if sp not in sys.path:
        sys.path.insert(0, sp)

from orig.simulation.compute_Amu import logistic_D, linear_gD  # type: ignore
from orig.simulation.compute_Umu import compute_U_from_A       # type: ignore
from orig.simulation.build_lattice import build_lattice        # type: ignore
from orig.simulation.measure_wilson import measure_wilson_loops  # type: ignore
from orig.simulation.plot_results import fit_string_tension    # type: ignore


def _ensure_dir(d: str) -> None:
    """Create directory *d* together with any missing parents; no-op if present."""
    target = pathlib.Path(d)
    target.mkdir(parents=True, exist_ok=True)


def _central_sigma_fit(W_dict: dict) -> float:
    """
    Run the project fitter on *W_dict* and return sigma as a plain float.

    The fitter may hand back either a bare value or a tuple/list; in the
    latter case the first element is taken as sigma.
    """
    result = fit_string_tension(W_dict)
    sigma = result[0] if isinstance(result, (tuple, list)) else result
    return float(sigma)


def _clean_sizes(W: dict, sizes: Sequence[int], eps: float = 1e-14) -> Sequence[int]:
    """
    Filter *sizes* down to those usable in a log-fit.

    A size is kept when W holds at least one sample for it and the mean of
    the absolute sample values is finite and strictly greater than *eps*.
    The surviving sizes are returned as a list of ints, in input order.
    """

    def _usable(size) -> bool:
        samples = np.asarray(W.get(size, []))
        if samples.size == 0:
            # Missing or empty entry: nothing to average.
            return False
        magnitude = float(np.mean(np.abs(samples)))
        return bool(np.isfinite(magnitude) and magnitude > eps)

    return [int(size) for size in sizes if _usable(size)]


def _ols_slope_and_se(W: dict, sizes: Sequence[int]) -> Tuple[float, float, bool]:
    """
    Straight-line least-squares fit of y = -log(mean|W_s|) against x = s^2.

    Sizes with no samples, or whose mean |W| is non-finite or not positive,
    are ignored. Returns (slope, se, used_flag):
      - fewer than 2 usable sizes, or a singular system -> (0.0, nan, False)
      - exactly 2 usable sizes -> (slope, nan, True); no SE is defined
      - 3 or more usable sizes -> (slope, slope standard error, True)
    """
    nan = float("nan")

    # Collect (area, -log mean|W|) pairs for every usable size.
    points = []
    for size in sizes:
        samples = np.asarray(W.get(size, []))
        if samples.size == 0:
            continue
        mean_abs = float(np.mean(np.abs(samples)))
        if np.isfinite(mean_abs) and mean_abs > 0:
            points.append((float(size * size), -np.log(mean_abs)))

    n_pts = len(points)
    if n_pts < 2:
        return 0.0, nan, False

    areas = [area for area, _ in points]
    y = np.asarray([val for _, val in points], float)

    # Design matrix columns: [area, 1] -> coefficients [slope, intercept].
    design = np.column_stack([areas, np.ones(n_pts)])
    gram = design.T @ design
    try:
        coef = np.linalg.solve(gram, design.T @ y)
    except np.linalg.LinAlgError:
        return 0.0, nan, False

    slope = float(coef[0])
    if n_pts < 3:
        # Two points determine the line exactly; there is no residual variance.
        return slope, nan, True

    resid = y - design @ coef
    resid_var = float((resid @ resid) / max(1, n_pts - 2))
    cov = resid_var * np.linalg.inv(gram)
    # Clamp at zero to guard against a tiny negative from round-off.
    return slope, float(np.sqrt(max(0.0, cov[0, 0]))), True


def _bootstrap_sigma_from_W(
    W_dict: dict,
    sizes: Sequence[int],
    B: int,
    rng: np.random.Generator,
    min_sizes_with_samples: int = 2,
    min_samples_per_size: int = 3,
) -> Tuple[float, bool]:
    """
    Bootstrap SE of sigma by resampling loop placements at each size and refitting.

    Only sizes holding at least *min_samples_per_size* samples take part; if
    fewer than *min_sizes_with_samples* such sizes exist, the data are too
    thin and (nan, False) is returned. The number of replicates is
    max(50, B). Each replicate resamples every eligible size with
    replacement (same sample count) and refits: the project fitter is used
    when there are >=3 eligible sizes, otherwise the OLS slope.

    A replicate whose OLS fit is unusable is skipped rather than recorded
    as 0.0, so failed fits cannot distort the spread (or fake a zero SE
    when every replicate fails).

    Returns (bs_se, used_bootstrap_flag); (nan, False) when fewer than two
    replicates survive or any replicate estimate is non-finite.
    """
    eligible = [s for s in sizes if len(W_dict.get(s, [])) >= min_samples_per_size]
    if len(eligible) < min_sizes_with_samples:
        return float("nan"), False

    pools = {s: np.asarray(W_dict[s]) for s in eligible}
    use_fitter = len(eligible) >= 3
    n_reps = int(max(50, B))  # keep it stable

    sigmas = []
    for _ in range(n_reps):
        Wb = {}
        for s in eligible:
            arr = pools[s]
            idx = rng.integers(0, arr.size, size=arr.size, endpoint=False)
            Wb[s] = arr[idx]
        if use_fitter:
            sigmas.append(_central_sigma_fit(Wb))
            continue
        slope, _, used = _ols_slope_and_se(Wb, eligible)
        if used:
            sigmas.append(float(slope))
        # else: drop this replicate (mirrors the size-jackknife behaviour)

    sigmas = np.asarray(sigmas, float)
    if sigmas.size <= 1 or not np.isfinite(sigmas).all():
        return float("nan"), False
    return float(np.std(sigmas, ddof=1)), True


def _jackknife_sigma_over_sizes(W_dict: dict, sizes: Sequence[int]) -> Tuple[float, bool]:
    """
    Delete-one-size jackknife SE of sigma; returns (jk_se, used_flag).

    Each size is dropped in turn and sigma is re-estimated from the rest:
    with >=3 remaining sizes via the project fitter, otherwise via the OLS
    slope (an unusable OLS fit contributes no estimate). Returns
    (nan, False) when fewer than 3 sizes are given, fewer than 2 estimates
    are obtained, or any estimate is non-finite.
    """
    nan = float("nan")
    n_sizes = len(sizes)
    if n_sizes < 3:
        return nan, False

    estimates = []
    for drop in range(n_sizes):
        kept = [s for pos, s in enumerate(sizes) if pos != drop]
        W_kept = {s: W_dict[s] for s in kept if s in W_dict}
        if len(kept) < 3:
            slope, _, ok = _ols_slope_and_se(W_kept, kept)
            if ok:
                estimates.append(float(slope))
        else:
            estimates.append(_central_sigma_fit(W_kept))

    estimates = np.asarray(estimates, float)
    n_est = len(estimates)
    if n_est <= 1 or not np.isfinite(estimates).all():
        return nan, False

    center = float(estimates.mean())
    spread = np.sum((estimates - center) ** 2)
    # Jackknife variance: (n-1)/n times the sum of squared deviations.
    return float(np.sqrt((n_est - 1) / n_est * spread)), True


def compute_sigma_c(
    *,
    L: int,
    gauge: str,
    kernel_path: str,
    flip_counts_path: str,
    pivot: Dict[str, float],
    b: float,
    k: float,
    n0: float,
    job_tag: str,
    work_dir: str,
    loop_sizes: Optional[Sequence[int]] = None,  # optional override from YAML
    bootstrap_reps: int = 600,
    bootstrap_seed: Optional[int] = 1337,
) -> Tuple[float, float]:
    """
    Return (sigma_c, sigma_c_err) using fit-based central value and robust fallback errors.

    Parameters (all keyword-only):
      L                 lattice size; 2*L*L links are expected in the inputs.
      gauge             gauge-group name; upper-cased and passed to compute_U_from_A.
      kernel_path       array file loaded with np.load; either a flat/column
                        array of 2*L*L entries or a (2*L*L, N, N) array.
      flip_counts_path  array file loaded with np.load; must hold 2*L*L values.
      pivot             mapping with optional keys "a" (default 1.0) and
                        "b" (default 0.0), forwarded to linear_gD.
      b                 multiplies g(D) when building the field A from the kernel.
      k, n0             forwarded to logistic_D.
      job_tag           suffix of the provenance file name.
      work_dir          output directory (created if missing).
      loop_sizes        optional loop-size override; entries outside
                        1..L//2 are dropped and duplicates are removed. If
                        fewer than 3 remain, the window falls back to
                        1..min(3, max(2, L//2)).
      bootstrap_reps    requested number of bootstrap replicates.
      bootstrap_seed    seed for the bootstrap RNG (None -> unseeded).

    Returns:
      (sigma, sigma_err). sigma comes from the project fitter when >=3 clean
      sizes exist, else from the OLS slope (0.0 if that is unusable).
      sigma_err is the quadrature sum of the available bootstrap and
      jackknife SEs; if neither is available it is the OLS slope SE, or 0.0
      when that is undefined as well.

    Raises:
      ValueError    on flip-count or kernel size/shape mismatch.
      RuntimeError  if the built lattice does not have 2*L*L links.

    Side effects: writes provenance_{job_tag}.txt into work_dir and prints a
    warning when both sigma and sigma_err are essentially zero.
    """
    _ensure_dir(work_dir)

    L = int(L)
    gauge = str(gauge).upper()
    bc = "periodic"

    # ---- Load inputs (strict) ----
    K = np.load(kernel_path)
    fc = np.load(flip_counts_path).astype(np.float64)
    expected_links = 2 * L * L
    if fc.size != expected_links:
        raise ValueError(f"flip-counts size={fc.size} != expected 2*L*L={expected_links}")

    # ---- Build lattice and fields ----
    lattice = build_lattice(L)
    if len(lattice) != expected_links:
        raise RuntimeError(f"lattice link count mismatch: {len(lattice)} vs {expected_links}")

    D  = logistic_D(fc, k=k, n0=n0)
    gD = linear_gD(D, a=float(pivot.get("a", 1.0)), b=float(pivot.get("b", 0.0)))

    # Field A from kernel K (supports U1 and SU(N) kernels)
    if K.ndim == 1 or (K.ndim == 2 and K.shape[1] == 1):
        if K.size != expected_links:
            raise ValueError(f"U1 kernel length={K.size} != 2*L*L={expected_links}")
        A = (b * gD) * K.reshape(-1)
        A_rms = float(np.sqrt(np.mean(np.abs(A)**2)))
        K_rms = float(np.sqrt(np.mean(np.abs(K.reshape(-1))**2)))
    elif K.ndim == 3:
        if K.shape[0] != expected_links or (K.shape[1] != K.shape[2]):
            raise ValueError(f"SU(N) kernel has shape {K.shape}, expected (2*L*L, N, N)")
        A = (b * gD)[:, None, None] * K
        A_rms = float(np.sqrt(np.mean(np.abs(A)**2)))
        K_rms = float(np.sqrt(np.mean(np.abs(K)**2)))
    else:
        raise ValueError(f"Unsupported kernel shape {K.shape}")

    U = compute_U_from_A(A, gauge_group=gauge)

    # ---- Loop-size window (robust default) ----
    half = L // 2
    if loop_sizes is None:
        sizes = list(range(1 if L <= 8 else 2, min(8, half) + 1))
    else:
        # A set removes duplicate entries: a repeated size would otherwise be
        # counted twice in the OLS fit and in the leave-one-size-out jackknife.
        sizes = sorted({int(s) for s in loop_sizes if 1 <= int(s) <= half})
    if len(sizes) < 3:
        # Guardrail shared by both paths: fall back to the smallest loops.
        sizes = list(range(1, min(3, max(2, half)) + 1))

    # ---- Measure Wilson loops and clean sizes to avoid degenerate fits ----
    W = measure_wilson_loops(lattice, U, sizes, bc)
    sizes_fit = _clean_sizes(W, sizes, eps=1e-14)

    # ---- Central estimate ----
    if len(sizes_fit) >= 3:
        sigma = _central_sigma_fit({s: W[s] for s in sizes_fit})
    else:
        # Fallback central value from OLS when fitter would be ill-conditioned
        sigma, _, used = _ols_slope_and_se(W, sizes_fit)
        sigma = float(sigma if used else 0.0)

    # ---- Error estimates (hierarchy: bootstrap + jackknife -> OLS fallback) ----
    # default_rng(None) already means "unseeded", so the seed is passed through.
    rng = np.random.default_rng(bootstrap_seed)

    # Both estimators only read from this mapping, so one copy serves both.
    W_clean = {s: W[s] for s in sizes_fit}
    bs_se, used_bs = _bootstrap_sigma_from_W(W_clean, sizes_fit, bootstrap_reps, rng)
    jk_se, used_jk = _jackknife_sigma_over_sizes(W_clean, sizes_fit)

    finite_parts = [x for x in (bs_se, jk_se) if np.isfinite(x)]
    if finite_parts:
        sigma_err = float(np.sqrt(np.sum(np.square(finite_parts))))
        used_ols = False
    else:
        _, ols_se, ols_ok = _ols_slope_and_se(W, sizes_fit)
        # Report OLS as a contributing error source only when it actually
        # supplied a finite SE; with two sizes the slope exists but its SE is
        # undefined and sigma_err is just the 0.0 placeholder.
        used_ols = bool(ols_ok and np.isfinite(ols_se))
        sigma_err = float(ols_se) if used_ols else 0.0

    # ---- Provenance & diagnostics ----
    means = {int(s): float(np.mean(np.abs(np.asarray(W[s])))) for s in sizes if s in W and len(W[s])}
    counts = {int(s): int(len(W[s])) for s in sizes if s in W}
    gD_min, gD_max = float(np.min(gD)), float(np.max(gD))
    with open(os.path.join(work_dir, f"provenance_{job_tag}.txt"), "w", encoding="utf-8") as f:
        f.write(f"L={L} gauge={gauge} b={b} k={k} n0={n0}\n")
        f.write(f"kernel={kernel_path}\nflip_counts={flip_counts_path}\n")
        f.write(f"sizes={sizes}\nclean_sizes={list(sizes_fit)}\n")
        f.write(f"A_rms={A_rms:.6e}  K_rms={K_rms:.6e}  gD_range=[{gD_min:.6e},{gD_max:.6e}]\n")
        f.write(f"counts={counts}\nmeans(|W|)={{{', '.join(f'{s}:{m:.6g}' for s,m in means.items())}}}\n")
        f.write("err_parts: ")
        f.write(f"bootstrap={'used' if used_bs else 'skip'} ")
        f.write(f"jackknife={'used' if used_jk else 'skip'} ")
        f.write(f"ols_fallback={'used' if used_ols else 'skip'}\n")
        f.write(f"sigma={sigma:.6g}  sigma_err≈{sigma_err:.6g}\n")

    # Guard: warn if everything is basically zero
    if (abs(sigma) < 1e-18) and (sigma_err < 1e-18):
        print(f"[WARN crossover] Near-degenerate estimate at L={L}, gauge={gauge}: "
              f"sigma≈0, A_rms={A_rms:.2e}, K_rms={K_rms:.2e}, gDΔ={(gD_max-gD_min):.2e}")

    return float(sigma), float(sigma_err)
